# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
from typing import *
from typing import Dict
from loguru import logger

from langchain.agents import Tool

from llm4crs.prompt import *
from llm4crs.critic import Critic
from llm4crs.agent import CRSAgent
from llm4crs.utils.prompt import CRSChatPrompt


class ToolBox:
    """Execute a JSON tool-using plan produced by the LLM against wrapped tools.

    The toolbox is exposed to the agent as a single tool. Its input is a JSON
    list of ``{"tool_name": ..., "input": ...}`` steps. Each step is run in
    order, and only the outputs of tools whose name contains "look up" or "map"
    (case-insensitive) are returned to the LLM.

    Attributes:
        name: Name under which the toolbox is registered as a tool.
        desc: Description of the toolbox shown to the LLM.
        tools: Mapping from tool name to an object exposing ``run(str)``.
        look_up_tool_name: Original name of the (last) tool whose name contains
            "look up", or ``None`` if there is no such tool.
        map_tool_name: Original name of the (last) tool whose name contains
            "map", or ``None`` if there is no such tool.
        failed_times: Count of failed plan executions; callers may reset it.
    """

    def __init__(self, name: str, desc: str, tools: Dict[str, Callable[..., Any]]):
        self.name = name
        self.desc = desc
        self.tools = tools
        # Keep the ORIGINAL spelling of the responding tools' names: they are
        # echoed back to the LLM and must match the keys of ``self.tools``
        # exactly, because the existence check in ``run`` is case-sensitive.
        # Default to None so ``run`` never hits an AttributeError when no such
        # tool is registered.
        self.look_up_tool_name: Optional[str] = None
        self.map_tool_name: Optional[str] = None
        for t_name in self.tools:
            lowered = t_name.lower()
            if 'look up' in lowered:
                self.look_up_tool_name = t_name
            if 'map' in lowered:
                self.map_tool_name = t_name

        self.failed_times = 0

    def run(self, inputs: str) -> str:
        """Execute a tool-using plan and return the responding tools' output.

        Args:
            inputs: JSON string generated by the LLM: a list of steps shaped
                like ``{"tool_name": TOOL, "input": INPUT}``. A non-string
                ``INPUT`` is re-serialized with ``json.dumps`` before being
                passed to the tool.

        Returns:
            The concatenated outputs of the "look up" / "map" tools, or a
            feedback message for the LLM explaining why the plan failed.
            ``failed_times`` is incremented when a step raises or when the
            plan produces no output.
        """
        try:
            steps = json.loads(inputs)
            # NOTE: keyed by tool name, so if a tool appears more than once only
            # its last input is kept (at the position of its first occurrence).
            plans = {t['tool_name']: t['input'] for t in steps}
        except Exception as e:
            # The example must be valid JSON (double quotes): json.loads rejects
            # Python-style single-quoted literals.
            return (
                f"An exception happens: {e}. The inputs should be a json string for tool using plan. "
                'The format should be like: [{"tool_name": TOOL-1, "input": INPUT-1}, ..., '
                '{"tool_name": TOOL-N, "input": INPUT-N}].'
            )

        # Check that every tool name in the plan exists. Names come from the
        # LLM and may not be strings, so stringify them before joining.
        tool_not_exist = [k for k in plans if k not in self.tools]
        if tool_not_exist:
            return (
                f"These tools do not exist: {', '.join(map(str, tool_not_exist))}. "
                f"Optional Tools: {', '.join(self.tools)}."
            )

        res = ""
        for k, v in plans.items():
            try:
                if not isinstance(v, str):
                    v = json.dumps(v)
                output = self.tools[k].run(v)
                lowered = k.lower()
                # Only the look-up and map tools produce text for the LLM.
                if ("look up" in lowered) or ("map" in lowered):
                    res += output
            except Exception as e:
                logger.debug(e)
                self.failed_times += 1
                return f"The input to tool {k} does not meet the format specification."

        if not res:
            self.failed_times += 1
            lookup = self.look_up_tool_name or "the look up tool"
            map_name = self.map_tool_name or "the map tool"
            res = (
                f"No output because tool plan is not correct. Only {lookup} and {map_name} would give response. "
                f"If look up information, use {lookup}. If recommendation, use {map_name} in the final step. "
                "Also, remember to use ranking tool before map."
            )
        return res



class CRSAgentPlanFirst(CRSAgent):
    """CRS agent that routes every tool call through a single plan executor.

    The individual tools are wrapped in one ``ToolBox`` (registered as
    'ToolExecutor'). The LLM sees only that executor and passes it a complete
    tool-using plan in a single call.
    """

    def setup_tools(self, tools: List[Callable[..., Any]]) -> List[Tool]:
        """Wrap ``tools`` in a ``ToolBox`` and expose it as the only agent tool.

        The toolbox is also stored on ``self.toolsbox`` so its failure counter
        can be reset by ``run`` and read through ``failed_times``.
        """
        registry = {t.name: t for t in tools}
        executor = ToolBox('ToolExecutor', TOOLBOX_DESC, registry)
        self.toolsbox = executor
        return [Tool(name=executor.name, func=executor.run, description=executor.desc)]

    def setup_prompts(self, tools: List[Tool]):
        """Build the plan-first chat prompt.

        The system template is filled with one "name: desc" line per tool in
        ``self._tools``, the bracketed list of their names, the executor's name
        (``self.tools[0].name``), and the entries of ``self._tool_names`` and
        ``self._domain_map``.
        """
        tools_desc = "\n".join(f"{t.name}: {t.desc}" for t in self._tools)
        joined_names = ", ".join(f"{t.name}" for t in self._tools)
        tool_names = f"[{joined_names}]"
        executor_name = self.tools[0].name
        template = SYSTEM_PROMPT_PLAN_FIRST.format(
            tools_desc=tools_desc,
            tool_exe_name=executor_name,
            tool_names=tool_names,
            **self._tool_names,
            **self._domain_map,
        )
        return CRSChatPrompt(
            table_info=self.item_corups.info(),
            intermediate_steps="",
            template=template,
            tools=tools,
            input_variables=["input", "intermediate_steps"],
            buffer=self.candidate_buffer,
            memory="",
            examples="",
            reflection="",
        )

    def run(self, input: Dict[str, str], chat_history: Optional[str] = None):
        """Run the agent on ``input`` after zeroing the toolbox failure counter."""
        # The failure counter is reset at the start of every run, so
        # ``failed_times`` reflects only the current run.
        self.toolsbox.failed_times = 0
        return super().run(input, chat_history)

    @property
    def failed_times(self):
        """Failed plan executions recorded by the toolbox since the last ``run`` began."""
        return self.toolsbox.failed_times

    